# src/evaluation/utils.py
import os
import shutil
import torch
import logging
import json
from ast import literal_eval

def setup_logging(log_file_path):
    """
    Sets up logging to both a file and the console.

    Returns the shared 'MADExperimentLogger' logger at INFO level. The file and
    console handlers are attached only the first time this is called in a process;
    later calls return the same logger unchanged, so a different ``log_file_path``
    passed later does not redirect the file output.

    Args:
        log_file_path (str): File to append log records to (UTF-8). Its parent
            directory is created if needed; a bare filename logs into the
            current working directory.

    Returns:
        logging.Logger: The configured logger.
    """
    log_dir = os.path.dirname(log_file_path)
    # dirname() is '' for a bare filename, and os.makedirs('') raises FileNotFoundError.
    if log_dir:
        os.makedirs(log_dir, exist_ok=True)

    # Use a unique name for the logger to avoid conflicts
    logger = logging.getLogger('MADExperimentLogger')
    logger.setLevel(logging.INFO)
    # NOTE(review): propagation to the root logger is left enabled; if the root
    # logger also has handlers, records may be emitted twice.

    # Prevent adding handlers multiple times if the function is called repeatedly
    if not logger.handlers:
        formatter = logging.Formatter('%(asctime)s - [%(levelname)s] - %(message)s', datefmt='%Y-%m-%d %H:%M:%S')

        # File handler (append mode, so re-runs keep earlier logs)
        file_handler = logging.FileHandler(log_file_path, mode='a', encoding='utf-8')
        file_handler.setLevel(logging.INFO)
        file_handler.setFormatter(formatter)

        # Console handler
        console_handler = logging.StreamHandler()
        console_handler.setLevel(logging.INFO)
        console_handler.setFormatter(formatter)

        logger.addHandler(file_handler)
        logger.addHandler(console_handler)

    return logger

def _atomic_write(write_fn, dest_path):
    """
    Produces ``dest_path`` via a sibling temporary file so a failed write cannot clobber it.

    ``write_fn(tmp_path)`` must create the temporary file; on success it is moved over
    ``dest_path`` with ``os.replace`` (an atomic rename on POSIX when both paths are on
    the same filesystem, which holds here because they share a directory). On failure
    the temporary file is removed and the exception propagates, leaving any existing
    ``dest_path`` untouched.
    """
    tmp_path = dest_path + '.tmp'
    try:
        write_fn(tmp_path)
        os.replace(tmp_path, dest_path)
    finally:
        # After a successful replace the temp file is gone; this only cleans up failures.
        if os.path.exists(tmp_path):
            try:
                os.remove(tmp_path)
            except OSError:
                pass


def save_checkpoint(state, is_best, output_dir, filename='checkpoint.pth.tar', best_filename='model_best.pth.tar'):
    """
    Saves a model checkpoint and copies it as the best model if it has the best performance.

    Both files are written through a temporary file and then renamed into place, so an
    interruption during saving leaves the previous checkpoint intact instead of a
    truncated one.

    Args:
        state: Object passed to ``torch.save`` (e.g. a dict of state_dicts and metadata).
        is_best (bool): If True, also copy the saved checkpoint to ``best_filename``.
        output_dir (str): Directory to write into; created if missing. An empty string
            means the current working directory.
        filename (str): Name of the checkpoint file inside ``output_dir``.
        best_filename (str): Name of the best-model copy inside ``output_dir``.
    """
    # os.makedirs('') raises FileNotFoundError; '' simply means the current directory.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    filepath = os.path.join(output_dir, filename)
    _atomic_write(lambda tmp_path: torch.save(state, tmp_path), filepath)

    if is_best:
        best_filepath = os.path.join(output_dir, best_filename)
        _atomic_write(lambda tmp_path: shutil.copyfile(filepath, tmp_path), best_filepath)

def save_config_used(config_dict, file_path, logger=None):
    """
    Saves the configuration used for an experiment to a JSON file.

    This is best-effort: failures are reported (through ``logger.error`` when a
    logger is given, otherwise ``print``) rather than raised, so a problem dumping
    the config does not abort the run. Values that are not JSON-serializable are
    written as ``str(value)``.

    Args:
        config_dict (dict): Configuration to save.
        file_path (str): Destination path; its parent directory is created if missing.
        logger (logging.Logger, optional): Logger for error messages.
    """
    try:
        # Serialize before opening the file: opening with 'w' truncates it, so a
        # serialization error (e.g. non-string dict keys) would otherwise destroy
        # an existing good config and leave a partial one behind.
        text = json.dumps(config_dict, indent=4, ensure_ascii=False, default=str)
        parent_dir = os.path.dirname(file_path)
        if parent_dir:
            os.makedirs(parent_dir, exist_ok=True)
        with open(file_path, 'w', encoding='utf-8') as f:
            f.write(text)
    except Exception as e:
        message = f"Error saving config to {file_path}: {e}"
        if logger:
            logger.error(message)
        else:
            print(message)

def load_checkpoint(filepath, model, optimizer=None, device='cpu', logger=None, strict=False):
    """
    Loads model and optimizer states from a saved checkpoint.

    The checkpoint may be a dict with a ``'state_dict'`` entry (plus an optional
    ``'optimizer'`` entry and any other metadata), or a bare model state_dict.
    Failures are logged (when a logger is given) and signalled by returning
    ``None``; this function does not raise.

    Args:
        filepath (str): Path to the checkpoint file.
        model (nn.Module): Model instance to load the state into.
        optimizer (torch.optim.Optimizer, optional): Optimizer instance to load the state into.
            Loaded only if the checkpoint has an ``'optimizer'`` entry. A failure here is
            logged as a warning and does not fail the whole load.
        device (str, optional): The device to map the loaded tensors to.
        logger (logging.Logger, optional): Logger for printing messages.
        strict (bool, optional): Whether to strictly enforce that the keys in state_dict match.
            With False, missing/unexpected keys are tolerated and logged as warnings,
            which suits fine-tuning or transfer learning. Tensor *shape* mismatches on
            matching keys still make ``load_state_dict`` fail, and ``None`` is returned.

    Returns:
        The loaded checkpoint object on success, otherwise ``None``.
    """
    from collections.abc import Mapping

    if not os.path.exists(filepath):
        if logger:
            logger.error(f"Checkpoint file not found: {filepath}")
        return None

    if logger:
        logger.info(f"Loading checkpoint file from: {filepath}")

    try:
        # NOTE(security): torch.load may unpickle arbitrary objects (depending on the
        # installed torch version's weights_only default); only load trusted files.
        checkpoint = torch.load(filepath, map_location=device)

        if not isinstance(checkpoint, Mapping):
            if logger:
                logger.error(f"Unsupported checkpoint format ({type(checkpoint).__name__}); expected a dict.")
            return None

        # A training checkpoint wraps the weights under 'state_dict'; otherwise the
        # whole file is treated as a bare state_dict.
        state_dict_to_load = checkpoint.get('state_dict', checkpoint)

        if not state_dict_to_load:
            if logger:
                if 'state_dict' in checkpoint:
                    logger.error("The 'state_dict' entry in the checkpoint is empty.")
                else:
                    logger.error("Could not find 'state_dict' in the checkpoint.")
            return None

        # `strict` decides whether key mismatches are an error; with strict=False they are
        # returned in `incompatible_keys` and only logged below.
        incompatible_keys = model.load_state_dict(state_dict_to_load, strict=strict)
        if logger:
            # Log mismatched keys at the WARNING level
            if incompatible_keys.missing_keys:
                logger.warning(f"Missing keys in model state_dict: {incompatible_keys.missing_keys}")
            if incompatible_keys.unexpected_keys:
                logger.warning(f"Unexpected keys in model state_dict: {incompatible_keys.unexpected_keys}")
            logger.info(f"Model weights loaded successfully from: {filepath}")

        if optimizer is not None and 'optimizer' in checkpoint:
            try:
                optimizer.load_state_dict(checkpoint['optimizer'])
                if logger:
                    logger.info("Optimizer state loaded successfully.")
            except Exception as e:
                # Best-effort: a stale/incompatible optimizer state should not block resuming.
                if logger:
                    logger.warning(f"Failed to load optimizer state: {e}. The optimizer will be re-initialized.")

        return checkpoint

    except Exception as e:
        if logger:
            logger.error(f"An error occurred while loading the checkpoint: {e}", exc_info=True)
        return None